Use meta device tensor to infer contiguity for expr-eval segments #5772

base: resetContiguityFromTensor
Conversation
Review updated until commit c81f895

Description

| Category | Relevant files |
|---|---|
| Configuration changes | |
| Enhancement | 6 files |
| Bug fix | 1 file |
| Miscellaneous | 1 file |
| Tests | 12 files |
PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Backward Compatibility

The InferContiguity option changes the behavior of MatmulOp::evaluate(). When disabled, it uses the old logic that assumes contiguous outputs; when enabled, it uses the new logic that infers actual contiguity. This could break existing code that depends on the old behavior. The PR should document this breaking change clearly and provide migration guidance.
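For concreteness, here is a minimal C++ sketch of the behavioral switch described above. It is illustrative rather than nvfuser's actual `MatmulOp::evaluate`, and the `infer_contiguity_enabled` flag stands in for the real option lookup:

```cpp
#include <ATen/ATen.h>

// Illustrative sketch: the old path forces a contiguous output; the new
// path keeps whatever strides ATen actually produced so that contiguity
// can be inferred downstream.
at::Tensor evaluateMatmul(
    const at::Tensor& a,
    const at::Tensor& b,
    bool infer_contiguity_enabled) {
  at::Tensor out = at::matmul(a, b);
  if (!infer_contiguity_enabled) {
    // Old behavior: downstream segments may assume contiguous strides.
    return out.contiguous();
  }
  // New behavior: preserve the actual strides.
  return out;
}
```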
…andling

- Renamed `inferOutputShapeAndContiguousStrides` to `inferContiguousOutputMetaTensor` for clarity.
- Updated function signatures to remove unnecessary parameters.
- Introduced `inferOutputMetaTensor` in `FusionKernelRuntime` to handle output shape inference for segmented groups.
- Enhanced `updateWithSegmentOutputs` to streamline output management without updating contiguity directly.
- Improved overall code organization and readability.
Greptile Summary

This PR fixes issue #4888, where `FusionExecutorCache` assumed that fusion segments always produce contiguous outputs, which does not hold for `ExpressionEvaluator` segments.

The implementation leverages PyTorch's meta device to compute shapes/strides without materializing actual tensors. The PR description mentions that some ATen ops' meta-device implementations are Python-based and can hang when called from C++ (due to GIL acquisition issues), requiring manual shape/stride computation using `at::empty_strided`.

Confidence Score: 4/5
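A hedged sketch of that manual fallback, assuming a simple row-major layout for illustration (a real implementation would compute op-specific strides):

```cpp
#include <ATen/ATen.h>

#include <vector>

// Build a result tensor on the meta device from manually computed
// sizes/strides, sidestepping a Python-based meta kernel entirely.
at::Tensor makeMetaResult(at::IntArrayRef sizes, at::ScalarType dtype) {
  std::vector<int64_t> strides(sizes.size());
  int64_t stride = 1;
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    strides[i] = stride;
    stride *= sizes[i];
  }
  return at::empty_strided(
      sizes, strides, at::TensorOptions().dtype(dtype).device(at::kMeta));
}
```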
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant User
participant FusionKernelRuntime
participant prepareInputs/getMaybeHeuristicsFor
participant inferOutputMetaTensor
participant ExpressionEvaluator
participant ATen as ATen (Meta Device)
participant updateContiguityOfSegmentOutputs
participant TensorView
User->>FusionKernelRuntime: runWithInputs(args)
FusionKernelRuntime->>prepareInputs/getMaybeHeuristicsFor: Prepare segment inputs
loop For each segment
prepareInputs/getMaybeHeuristicsFor->>inferOutputMetaTensor: Infer output shape/stride
alt is_expr_eval && InferContiguity enabled
inferOutputMetaTensor->>ExpressionEvaluator: Create ExpressionEvaluator
loop For each input
inferOutputMetaTensor->>ATen: at::empty_strided(sizes, strides, device=meta)
ATen-->>inferOutputMetaTensor: meta tensor
inferOutputMetaTensor->>ExpressionEvaluator: bind(input, meta_tensor)
end
loop For each output
ExpressionEvaluator->>ATen: evaluate() - run ATen ops on meta device
ATen-->>ExpressionEvaluator: result meta tensor with actual strides
ExpressionEvaluator-->>inferOutputMetaTensor: result
end
else not expr_eval or InferContiguity disabled
inferOutputMetaTensor->>inferOutputMetaTensor: inferContiguousOutputMetaTensor()
Note right of inferOutputMetaTensor: Assumes contiguous output
end
inferOutputMetaTensor-->>prepareInputs/getMaybeHeuristicsFor: group_runtime_outputs
prepareInputs/getMaybeHeuristicsFor->>updateContiguityOfSegmentOutputs: Update TensorView contiguity
alt InferContiguity enabled
loop For each output TensorView
updateContiguityOfSegmentOutputs->>TensorView: ir_utils::resetContiguityFromTensor(tv, tensor)
Note right of TensorView: Updates contiguity info from actual tensor strides
end
end
updateContiguityOfSegmentOutputs-->>prepareInputs/getMaybeHeuristicsFor: done
end
prepareInputs/getMaybeHeuristicsFor-->>FusionKernelRuntime: all_runtime_inputs prepared
FusionKernelRuntime->>FusionKernelRuntime: Execute segments with correct stride info
```
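To make the meta-device step in the diagram concrete, here is a small self-contained sketch in plain ATen (not nvfuser code) showing how evaluating an op on meta tensors recovers real output strides without allocating data:

```cpp
#include <ATen/ATen.h>

int main() {
  // A contiguous 4x8 input that exists only as metadata.
  at::Tensor in = at::empty_strided(
      {4, 8},
      {8, 1},
      at::TensorOptions().dtype(at::kFloat).device(at::kMeta));
  // Slicing columns keeps the parent's strides, so the result is
  // non-contiguous: sizes {4, 4}, strides {8, 1}. This is exactly the
  // kind of output an ExpressionEvaluator segment can produce.
  at::Tensor sliced = in.slice(/*dim=*/1, /*start=*/0, /*end=*/4);
  return sliced.is_contiguous() ? 1 : 0; // returns 0: non-contiguous
}
```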
20 files reviewed, 1 comment
```diff
  auto fusion_to_run = segmented_fusion_->makeFusion(group_to_run).second;
- auto group_runtime_outputs = inferOutputShapeAndContiguousStrides(
-     fusion_to_run.get(), group_runtime_inputs);
+ auto group_runtime_outputs = inferOutputMetaTensor(
```
I'm losing track of the code. Does group_runtime_inputs contain meta tensors or real tensors at this moment? The setDeviceIndex call seems to say they are real tensors.
IIUC, in prepareInputs, group_runtime_inputs contains real tensors (but inferOutputShapeAndContiguousStrides still returns meta tensors), whereas in getMaybeHeuristicsFor, group_runtime_inputs contains meta tensors.
Got it. Should setDeviceIndex at line 419 be removed? Is it safe or necessary? (I don't think your PR changes the situation; just OOC).
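(For anyone following this thread, a hypothetical debugging helper; `reportDeviceKinds` is not part of the PR, but `Tensor::is_meta()` is a real ATen API that answers exactly this question:)

```cpp
#include <ATen/ATen.h>

#include <iostream>
#include <vector>

// Hypothetical helper: print whether each argument is a meta or a real
// tensor, e.g. to inspect group_runtime_inputs at a breakpoint.
void reportDeviceKinds(const std::vector<at::Tensor>& args) {
  for (const at::Tensor& t : args) {
    std::cout << (t.is_meta() ? "meta" : "real") << " tensor on "
              << t.device() << "\n";
  }
}
```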
Co-authored-by: Jingyue Wu <[email protected]>
```diff
  args_manager.updateWithSegmentOutputs(
      group_to_run->outputs(), group_runtime_outputs, run_order_id);
+
+ updateContiguityOfSegmentOutputs(group_to_run, group_runtime_outputs);
```
Is this to hide some bugs in mark_aliases_prepare or allocation_order_inference? The TensorViews in the complete fusion and therefore in segments ought to be correct after preseg.
How do you define "hide a bug"? We need the correct contiguity eventually, which is only possible after we know the scheduler from segmentation. So why isn't this just writing the correct information, instead of hiding a bug?
> which is only possible after we know the scheduler of segmentation

But scheduling happens after prepareInputs:

Fuser/csrc/runtime/fusion_kernel_runtime.cpp, line 431 in 352dcbf:

```cpp
compileKernel(group_runtime_inputs, group_to_run);
```
I'm probably missing some important details that are so obvious to you. Let me try to remove this line and see where things break...
`$ _bn && pytest tests/python/direct/test_python_frontend.py -k test_issue4888 -vs` passes with the following patch:

```diff
diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp
index e025d29d..132cba82 100644
--- a/csrc/runtime/fusion_kernel_runtime.cpp
+++ b/csrc/runtime/fusion_kernel_runtime.cpp
@@ -427,8 +427,6 @@ std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
     // map output args to tensor map
     args_manager.updateWithSegmentOutputs(
         group_to_run->outputs(), group_runtime_outputs, run_order_id);
-
-    updateContiguityOfSegmentOutputs(group_to_run, group_runtime_outputs);
   }
   return all_runtime_inputs;
```

But let me try other tests as well...
I missed the other call to updateContiguityOfSegmentOutputs. After removing that, I see SegmentationTest.RevertPrivatizedUpcast fails. Let me try to understand the error...
```
$ bin/test_nvfuser --gtest_filter=SegmentationTest.RevertPrivatizedUpcast
Running main() from /opt/pytorch/nvfuser/third_party/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = SegmentationTest.RevertPrivatizedUpcast
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from SegmentationTest
[ RUN      ] SegmentationTest.RevertPrivatizedUpcast
/opt/pytorch/nvfuser/tests/cpp/test_segmentation.cpp:855: Failure
Expected equality of these values:
  num_upcast_ops
    Which is: 1
  2
To reproduce: NVFUSER_TEST_RANDOM_SEED=1768609993 NVFUSER_TEST_ATEN_RANDOM_SEED=0 test_nvfuser --gtest_filter='SegmentationTest.RevertPrivatizedUpcast'
[  FAILED  ] SegmentationTest.RevertPrivatizedUpcast (218 ms)
[----------] 1 test from SegmentationTest (218 ms total)
[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (218 ms total)
[  PASSED  ] 0 tests.
[  FAILED  ] 1 test, listed below:
[  FAILED  ] SegmentationTest.RevertPrivatizedUpcast

1 FAILED TEST
```
Fixes #4888
Stacked on #5766
I used to work on #5082 for this fix, but I hit too many blockers: that PR interacts with many new assumptions/hacks/unfinalized designs around allocation domains, stream-sharded tensors, multi-device, etc., and new changes kept landing on the main branch that broke #5082. This delayed the PR for a very long time, so I recreated it as this PR, which is more friendly to incremental development.
Today, in the main branch, `FusionExecutorCache` assumes that fusion segments always generate contiguous tensors. This is not true for `ExpressionEvaluator` segments; for example, ATen's slice op returns non-contiguous tensors. It is worth mentioning that, because segmentation and scheduler selection depend on inputs, the contiguity of intermediate results also depends on inputs.

This PR adds `FusionKernelRuntime::inferOutputMetaTensor`, which replaces `inferOutputShapeAndContiguousStrides`, to infer the output shape and stride of each segment. Both `FusionKernelRuntime::inferOutputMetaTensor` and `inferOutputShapeAndContiguousStrides` store their result as a tensor on the meta device. The difference is that `FusionKernelRuntime::inferOutputMetaTensor` will actually run the segment on the meta device if the segment is scheduled to run by `ExpressionEvaluator`, while `inferOutputShapeAndContiguousStrides` just assumes the output to be contiguous.

Because `FusionKernelRuntime::inferOutputMetaTensor` will run the segment on the meta device, the relevant ops' `MyOp::evaluate` must work for the meta device. There is good news and bad news for this design. The good news is that most `MyOp::evaluate` implementations just call `at::` ops, which usually already support the meta device, and PyTorch designed the meta device to keep its behavior on par with CUDA. The bad news is that many ops' meta-device implementations live in Python, and running such an `at::` op from C++ would hang due to the inability to grab Python's GIL (thanks @naoyam for help debugging!). In that case, the corresponding `MyOp::evaluate` must manually compute the shape and stride and use `at::empty_strided(device=meta)` to create the result.

Besides `FusionKernelRuntime::inferOutputMetaTensor`, this PR also adds `FusionKernelRuntime::updateContiguityOfSegmentOutputs`, which updates the segment output `TensorView`s' contiguity based on the inferred shape and stride.

This PR adds an enable option, "infer-contiguity", to roll out this feature incrementally. When "infer-contiguity" is disabled, `FusionKernelRuntime::inferOutputMetaTensor` falls back to the behavior of `inferOutputShapeAndContiguousStrides`, and `FusionKernelRuntime::updateContiguityOfSegmentOutputs` is a no-op. The plan is to merge this PR without enabling "infer-contiguity" for the currently failing tests, and I will write new PRs fixing those tests one by one.
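For intuition, here is a hedged sketch of what "update contiguity from the inferred tensor" amounts to; the PR's actual logic lives in `ir_utils::resetContiguityFromTensor` and `FusionKernelRuntime::updateContiguityOfSegmentOutputs`, which also handle allocation domains and broadcast dimensions that this illustration ignores:

```cpp
#include <ATen/ATen.h>

#include <vector>

// Derive per-dimension contiguity flags from sizes/strides, e.g. taken
// from a meta tensor produced by segment evaluation. A dimension is
// contiguous iff its stride equals the next inner dimension's stride
// times that dimension's size.
std::vector<bool> contiguityFromStrides(
    at::IntArrayRef sizes,
    at::IntArrayRef strides) {
  std::vector<bool> contiguity(sizes.size());
  int64_t inner_stride = 1; // implicit stride past the innermost dim
  int64_t inner_size = 1;
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    contiguity[i] = (strides[i] == inner_stride * inner_size);
    inner_stride = strides[i];
    inner_size = sizes[i];
  }
  return contiguity;
}

// For ATen's slice example above: sizes {4, 4} with strides {8, 1} yield
// {false, true} -- the outer dimension is non-contiguous.
```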